In [ ]:
import os
from os import path
from astropy.io import fits
import astropy.units as u
from astropy.table import Table
from astropy.constants import G

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import h5py
from sqlalchemy import func
from scipy.optimize import root

from twoface.config import TWOFACE_CACHE_PATH
from twoface.db import (db_connect, AllStar, AllVisit, AllVisitToAllStar, RedClump,
                        StarResult, Status, JokerRun, initialize_db)

from thejoker import JokerParams, TheJoker, JokerSamples
from twoface.sample_prior import make_prior_cache
from twoface.io import load_samples
from twoface.plot import plot_data_orbits

from scipy.special import logsumexp  # scipy.misc.logsumexp is deprecated

In [ ]:
with h5py.File('../cache/apogee-jitter.hdf5', 'r') as f:
    print(len(f.keys()))

In [ ]:
# override the imported default so everything below uses the local cache
TWOFACE_CACHE_PATH = path.abspath('../cache/')

In [ ]:
Session, _ = db_connect(path.join(TWOFACE_CACHE_PATH, 'apogee.sqlite'))
session = Session()

In [ ]:
# load the run parameters:
run = session.query(JokerRun).filter(JokerRun.name == 'apogee-jitter').one()

# load the posterior samples:
samples_file = path.join(TWOFACE_CACHE_PATH, '{0}.hdf5'.format(run.name))

In [ ]:
def ln_normal(x, mu, std):
    return -0.5 * (((x - mu) / std)**2 + np.log(2*np.pi*std**2))

def ln_normal_mixture(x, amp, mu, std):
    n_components = len(amp)
    
    lls = []
    for j in range(n_components):
        lls.append(ln_normal(x, mu[j], std[j]) + np.log(amp[j]))
    
    return logsumexp(lls, axis=0)

# test against (slower) scipy implementation:
from scipy.stats import norm
derp = np.random.uniform(-2, 2, size=100)
pars = np.random.uniform(1E-3, 10, size=2)
assert np.allclose(norm.logpdf(derp, loc=pars[0], scale=pars[1]),
                   ln_normal(derp, *pars))

assert np.allclose(ln_normal_mixture(derp, [1.], [pars[0]], [pars[1]]),
                   ln_normal(derp, *pars))

In [ ]:
x = np.linspace(-10, 10, 1024)

plt.plot(x, np.exp(ln_normal_mixture(x, [0.2, 0.8], [-4, 4], [0.5, 1])), 
         marker='')
plt.axvline(-4)
plt.axvline(4)

Load data by getting particular stars:


In [ ]:
# The aspcapflag bitmask cut removes STAR_WARN (bit 7) and STAR_BAD (bit 23)
# The starflag bitmask cut removes SUSPECT_BROAD_LINES (bit 17)
# The logg cut removes TRGB stars, which have too much intrinsic jitter
stars = session.query(AllStar).join(StarResult, JokerRun, Status, AllVisitToAllStar, AllVisit)\
                              .filter(Status.id > 0)\
                              .filter(JokerRun.name == 'apogee-jitter')\
                              .filter(AllStar.aspcapflag.op('&')(2**np.array([7, 23])) == 0)\
                              .filter(AllStar.starflag.op('&')(2**np.array([17])) == 0)\
                              .filter(AllStar.logg > 2)\
                              .group_by(AllStar.apstar_id)\
                              .having(func.count(AllVisit.id) >= 10)\
                              .all()
#                              .limit(1024).all()
print(len(stars))

In [ ]:
K_n = []
apogee_ids = []
with h5py.File(samples_file, 'r') as f:
    # Values that aren't filled get set to nan
    N = len(stars)
    y_nk = np.full((N, 256), np.nan)
    
    for n, key in enumerate([s.apogee_id for s in stars]):
        K = len(f[key]['jitter'])
        
        s = f[key]['jitter'][:] * 1000. # km/s to m/s
        y_nk[n,:K] = np.log(s**2)
        K_n.append(K)
        apogee_ids.append(key)

K_n = np.array(K_n)
apogee_ids = np.array(apogee_ids)

# for nulling out the probability for non-existing samples
# (unfilled entries are NaN, per the np.full() call above)
mask = np.zeros_like(y_nk)
mask[np.isnan(y_nk)] = -np.inf

In [ ]:
plt.hist(K_n)
plt.yscale('log')
plt.xlabel('$K_n$')

In [ ]:
for apogee_id in apogee_ids[(K_n < 10)][:20]:
# for apogee_id in apogee_ids[(K_n > 10) & (K_n < 100)][:20]:
# for apogee_id in apogee_ids[(K_n > 100)][:20]:
    star = session.query(AllStar).filter(AllStar.apogee_id == apogee_id).limit(1).one()
    data = star.apogeervdata(clean=True)

    # NOTE: VelocityTrend1 (and RVOrbit below) are assumed importable from the
    # version of thejoker used for this run
    samples = JokerSamples(trend_cls=VelocityTrend1, **load_samples(samples_file, apogee_id))
    
    fig, axes = plt.subplots(2, 1, figsize=(10, 10), sharex=True)
    
    fig = plot_data_orbits(data, samples, ax=axes[0])
    axes[0].set_title(r'$s_{{\rm max}} = ${0:.3f}'.format(samples['jitter'].max().to(u.m/u.s)), 
                      fontsize=18)
    
    # residuals
    ax = axes[1]
    for label, j in zip(['max($P$)', 'min($P$)'], [samples['P'].argmax(), samples['P'].argmin()]):
        this_samples = samples[j]

        trend_samples = dict()
        for k in samples.trend_cls.parameters:
            trend_samples[k] = this_samples.pop(k)
        trend = samples.trend_cls(**trend_samples)
        orbit = RVOrbit(trend=trend, **this_samples)

        ax.errorbar(data.t.mjd, (data.rv - orbit.generate_rv_curve(data.t)).to(u.km/u.s).value, 
                    data.stddev.to(u.km/u.s).value,
                    linestyle='none', marker='o', label=label)
        
    ax.set_ylabel('residuals [{0:latex_inline}]'.format(u.km/u.s))
    ax.set_xlabel('BMJD')
    ax.axhline(0.)
    ax.legend(loc='best', fontsize=16)
    
    fig.savefig("../plots/1-nsamples <10/{0}.png".format(apogee_id), dpi=200)
    # fig.savefig("../plots/2-nsamples 10–100/{0}.png".format(apogee_id), dpi=200)
    # fig.savefig("../plots/3-nsamples >100/{0}.png".format(apogee_id), dpi=200)
    plt.close('all')

Load data by getting a random batch of some size:


In [ ]:
# K_n = []
# apogee_ids = []
# with h5py.File(samples_file) as f:
#     # Only load 2000 stars for now
#     N = 2000
#     # N = len(f.keys())
    
#     # Values that aren't filled get set to nan
#     y_nk = np.full((N, 128), np.nan)
    
#     for n,key in enumerate(f):
#         K = len(f[key]['jitter'])
        
#         s = f[key]['jitter'][:] * 1000. # km/s to m/s
#         y_nk[n,:K] = np.log(s**2)
#         K_n.append(K)
#         apogee_ids.append(key)
        
#         if n >= (N-1): 
#             break
            
#         elif n % 1000 == 0:
#             print(n)    

# K_n = np.array(K_n)
# apogee_ids = np.array(apogee_ids)

# # for nulling out the probability for non-existing samples
# mask = np.zeros_like(y_nk)
# mask[y_nk == 9999.] = -np.inf

Re-compute the value of the interim prior at the positions of the samples:
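
Each sample $y_{nk}$ was generated under the run's interim prior on $y = \ln s^2$, a single Gaussian whose parameters are stored on the `JokerRun` row; the hierarchical model below divides this back out:

$$\ln p_0(y_{nk}) = \ln \mathcal{N}\left(y_{nk} \mid \mu_0, \sigma_0^2\right), \qquad \mu_0 = {\tt run.jitter\_mean}, \quad \sigma_0 = {\tt run.jitter\_stddev}$$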


In [ ]:
ln_p0 = ln_normal(y_nk, float(run.jitter_mean), float(run.jitter_stddev))

Check the posterior samples:


In [ ]:
bins = np.linspace(-8, 20, 32)
plt.hist(np.ravel(y_nk)[np.isfinite(np.ravel(y_nk))], bins=bins, density=True, alpha=0.6, label='all samples');
plt.hist(np.nanmedian(y_nk, axis=1), bins=bins, density=True, alpha=0.6, label='median over $k$');
plt.legend(loc='upper left', fontsize=16)
plt.xlabel(r'$y = \ln s^2$')

In [ ]:
teff = [s.teff for s in stars]
logg = [s.logg for s in stars]
snr = [s.snr for s in stars]

fig, axes = plt.subplots(2, 2, figsize=(7,7.5), sharex='col', sharey='row')

style = dict(c=np.nanmax(y_nk, axis=1), marker='o', s=8, alpha=0.65,
             vmin=2, vmax=13, linewidth=0)
c = axes[0,0].scatter(teff, snr, **style)
axes[1,0].scatter(teff, logg, **style)
axes[1,1].scatter(snr, logg, **style)

axes[0,0].set_xlim(6000, 3500)
axes[0,0].set_yscale('log')
axes[1,1].set_xscale('log')
axes[1,0].set_ylim(4, 0)

axes[0,0].set_ylabel('SNR')
axes[1,0].set_xlabel('Teff')
axes[1,0].set_ylabel('log$g$')
axes[1,1].set_xlabel('SNR')

fig.tight_layout()

axes[0,1].set_visible(False)

fig.subplots_adjust(left=0.1, right=0.95, top=0.9)

cax = fig.add_axes([0.1, 0.92, 0.85, 0.025])
cb = fig.colorbar(c, cax=cax, orientation='horizontal')
cb.set_label(r'${\rm max}_k\left(y_{nk}\right)$', labelpad=10)
cb.ax.xaxis.tick_top()
cb.ax.xaxis.set_label_position('top')
cb.set_clim(style['vmin'], style['vmax'])

In [ ]:
fig, ax = plt.subplots(1, 1, figsize=(6,5))

style = dict(marker='o', s=8, alpha=0.25, linewidth=0)
c = ax.scatter(snr, np.nanmax(y_nk, axis=1), **style)
ax.set_xscale('log')
ax.set_ylim(2, 15)

ax.set_xlabel('SNR')
ax.set_ylabel(r'${\rm max}_k\left(y_{nk}\right)$')

fig.tight_layout()

The star with the largest value of the smallest $y$
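
That is, $n^\star = {\rm argmax}_n \left(\min_k y_{nk}\right)$: every posterior sample for this star has a large jitter value.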


In [ ]:
minmax = []
# need the loop because some stars have fewer than 256 samples
for i, K in enumerate(K_n):
    minmax.append(y_nk[i, :K].min())
i = np.argmax(minmax)

print(y_nk[i, :K_n[i]].min(), '{0:.2f} m/s'.format(np.sqrt(np.exp(y_nk[i, :K_n[i]].min()))), apogee_ids[i])
print(i)

star = session.query(AllStar).filter(AllStar.apogee_id == apogee_ids[i]).limit(1).one()
data = star.apogeervdata(clean=True)

with h5py.File(samples_file, 'r') as f:
    samples = JokerSamples.from_hdf5(f[apogee_ids[i]])
_ = plot_data_orbits(data, samples)

# residuals?
for j in range(len(samples)):
    orbit = samples.get_orbit(j)
    
    fig, ax = plt.subplots(1,1)
    ax.errorbar(data.t.mjd, (data.rv - orbit.radial_velocity(data.t)).to(u.km/u.s).value, 
                data.stddev.to(u.km/u.s).value,
                linestyle='none', marker='o')
    ax.set_ylabel('residuals [{0:latex_inline}]'.format(u.km/u.s))
    ax.set_xlabel('BMJD')
    
    break

The star with the smallest value of the largest $y$
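
That is, $n^\star = {\rm argmin}_n \left(\max_k y_{nk}\right)$: every posterior sample for this star has a small jitter value.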


In [ ]:
minmax = []
# need the loop because some stars have fewer than 256 samples
for i, K in enumerate(K_n):
    minmax.append(y_nk[i, :K].max())
i = np.argmin(minmax)

print(y_nk[i, :K_n[i]].max(), '{0:.2f} m/s'.format(np.sqrt(np.exp(y_nk[i, :K_n[i]].max()))), apogee_ids[i])
print(i)

star = session.query(AllStar).filter(AllStar.apogee_id == apogee_ids[i]).limit(1).one()
data = star.apogeervdata(clean=True)

with h5py.File(samples_file, 'r') as f:
    samples = JokerSamples.from_hdf5(f[apogee_ids[i]])
_ = plot_data_orbits(data, samples)

# residuals?
for j in range(len(samples)):
    orbit = samples.get_orbit(j)
    
    fig, ax = plt.subplots(1,1)
    ax.errorbar(data.t.mjd, (data.rv - orbit.radial_velocity(data.t)).to(u.km/u.s).value, 
                data.stddev.to(u.km/u.s).value,
                linestyle='none', marker='o')
    ax.set_ylabel('residuals [{0:latex_inline}]'.format(u.km/u.s))
    ax.set_xlabel('BMJD')
    
    break

Hierarchical inference of jitter parameter distribution
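
For each star $n$ with data $D_n$, the marginal likelihood under population parameters $\alpha$ is estimated from that star's $K_n$ posterior samples with the usual importance-sampling trick (this is what `Model.ln_likelihood` computes below):

$$p(D_n \mid \alpha) \approx \frac{1}{K_n} \sum_{k=1}^{K_n} \frac{p(y_{nk} \mid \alpha)}{p_0(y_{nk})}$$

where $p(y \mid \alpha)$ is a single Gaussian or a Gaussian mixture in $y$, and the full log-likelihood is $\sum_n \ln p(D_n \mid \alpha)$.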


In [ ]:
x = np.random.random(size=1000)
%timeit ln_normal_mixture(x, [0.2, 0.8], [1, 10], [1, 5])
%timeit ln_normal(x, 0.2, 0.8)

In [ ]:
class Model:
    
    def __init__(self, y_nk, K_n, ln_p0, n_components=1):
        self.y = y_nk
        self.K = K_n
        self.ln_p0 = ln_p0
        self.n_components = int(n_components)
        
        self.ln_norm_func = ln_normal
        if self.n_components > 1:
            self.ln_norm_func = ln_normal_mixture

    def ln_likelihood(self, **kwargs):
        """ Original, single Gaussian implementation """
        delta_ln_prior = self.ln_norm_func(self.y, **kwargs) - self.ln_p0
        delta_ln_prior[np.isnan(delta_ln_prior)] = -np.inf
        return logsumexp(delta_ln_prior, axis=1) - np.log(self.K)
    
    def ln_prior(self, **kwargs):
        lp = 0.
        
        amp = kwargs.get('amp', None)
        if amp is not None:
            amp = np.array(amp)
            if amp.sum() > 1:
                return -np.inf
            
            if np.any(amp < 0):
                return -np.inf
        
        # enforce ordering of the means (only meaningful for a mixture;
        # np.argsort fails on the scalar mu of the single-Gaussian case)
        if self.n_components > 1:
            mu = np.atleast_1d(kwargs['mu'])
            if not np.all(np.argsort(mu) == np.arange(self.n_components)):
                return -np.inf
        
        # 1/sigma prior
        lp += -np.sum(np.log(kwargs['std'])) 
        
        return lp
    
    def unpack_pars(self, pars):
        # map a flat parameter vector to the kwargs expected by ln_norm_func
        if self.n_components == 1:
            mu, std = pars
            return dict(mu=mu, std=std)
            
        else:
            # the last amplitude is fixed by requiring the amplitudes to sum to 1
            amp = np.concatenate((pars[:self.n_components-1], [1-sum(pars[:self.n_components-1])]))
            mu = pars[self.n_components-1:2*self.n_components-1]
            std = pars[2*self.n_components-1:]
            return dict(amp=amp, mu=mu, std=std)
    
    def pack_pars(self, mu, std, amp=None):
        # inverse of unpack_pars; the last (dependent) amplitude is dropped
        if self.n_components == 1:
            return np.array([mu, std])
        return np.concatenate((np.atleast_1d(amp)[:-1],
                               np.atleast_1d(mu), np.atleast_1d(std)))

    def ln_prob(self, pars_vec):
        pars_kw = self.unpack_pars(pars_vec)
        
        lp = self.ln_prior(**pars_kw)
        if not np.isfinite(lp):
            return -np.inf

        ll_n = self.ln_likelihood(**pars_kw)
        if not np.all(np.isfinite(ll_n)):
            return -np.inf

        return np.sum(ll_n)
    
    def __call__(self, p):
        return self.ln_prob(p)

In [ ]:
# slc = (slice(0,3),) # single
# slc = np.array([512, 777])  # the two minmax stars above
slc = (slice(None),) # all
# slc = np.array([225, 139])

mm = Model(y_nk[slc], K_n[slc], ln_p0[slc], n_components=1)
mm([-2, 4.]), mm([2, 4.])

In [ ]:
bins = np.linspace(-5, 18, 55)

_n_sub = y_nk[slc].shape[0]
for _n in range(min(_n_sub, 8)):
    plt.hist(y_nk[slc][_n][np.isfinite(y_nk[slc][_n])], bins=bins, 
             alpha=0.5, label='star {0}'.format(_n))

plt.legend(loc='best')
    
vals = np.linspace(bins.min(), bins.max(), 1000)
# lls = ln_normal_mixture(vals, [0.2, 0.8], [0, 1.], [6., 2.])
# plt.plot(vals, np.exp(lls))

In [ ]:
mm = Model(y_nk[slc], K_n[slc], ln_p0[slc])

# Single-component likelihood
sigma_y = 2.
# sigma_y = np.std(y_nk[slc].ravel())

lls = []
vals = np.linspace(-5, 15, 128)
for val in vals:
    lls.append(mm([val, sigma_y]))  # ln_prob already sums over stars
    
fig, axes = plt.subplots(1, 2, figsize=(12,5), sharex=True)

axes[0].plot(vals, lls)
axes[0].set_ylabel(r'$\ln p(\{D_n\}|\alpha)$')
axes[1].plot(vals, np.exp(lls - np.max(lls)))
axes[1].set_ylabel(r'$p(\{D_n\}|\alpha)$')

# axes[1].axvline(np.mean(y_nk[slc].ravel()))

axes[0].set_xlabel(r'$\mu_y$')
axes[1].set_xlabel(r'$\mu_y$')

axes[0].xaxis.set_ticks(np.arange(vals.min(), vals.max()+1, 2))

fig.tight_layout()

In [ ]:
# Mixture model
mmix = Model(y_nk[slc], K_n[slc], ln_p0[slc], 
             n_components=2)

lls = []
vals = np.linspace(-5, 15, 128)
for val in vals:
    lls.append(mmix([0.8, val, 10, 2, 2]))
    
fig, axes = plt.subplots(1, 2, figsize=(12,5), sharex=True)

axes[0].plot(vals, lls)
axes[0].set_ylabel(r'$\ln p(\{D_n\}|\alpha)$')
axes[1].plot(vals, np.exp(lls - np.max(lls)))
axes[1].set_ylabel(r'$p(\{D_n\}|\alpha)$')

# axes[1].axvline(np.mean(y_nk[slc].ravel()))

axes[0].set_xlabel(r'$\mu_y$')
axes[1].set_xlabel(r'$\mu_y$')

axes[0].xaxis.set_ticks(np.arange(vals.min(), vals.max()+1, 2))

fig.tight_layout()


In [ ]:
mmix = Model(y_nk[slc], K_n[slc], ln_p0[slc], 
             n_components=1)

In [ ]:
mu_grid = np.linspace(-10, 20, 27)
# var_grid = np.linspace(0.1, 10, 25)**2
std_grid = np.logspace(-3, 1.5, 25)
mu_grid, std_grid = np.meshgrid(mu_grid, std_grid)

probs = np.array([mmix([m, s])
                  for (m, s) in zip(mu_grid.ravel(), std_grid.ravel())])

In [ ]:
probs.min(), probs.max()

In [ ]:
mu_grid.ravel()[probs.argmax()], std_grid.ravel()[probs.argmax()]

In [ ]:
plt.figure(figsize=(6,5))

plt.pcolormesh(mu_grid, std_grid,
               probs.reshape(mu_grid.shape),
               cmap='Blues', vmin=-1000, vmax=probs.max())
# plt.pcolormesh(mu_grid, std_grid,
#                np.exp(probs.reshape(mu_grid.shape)),
#                cmap='Blues')

plt.yscale('log')
plt.colorbar()
plt.xlabel(r'$\mu_y$')
plt.ylabel(r'$\sigma_y$')


In [ ]:
from scipy.optimize import minimize

In [ ]:
mmix = Model(y_nk[slc], K_n[slc], ln_p0[slc], 
             n_components=1)

In [ ]:
# p0 = [0.8, 7, 10, 2, 2]
p0 = [10., 2]
mmix(p0)

In [ ]:
res = minimize(lambda *args: -mmix(*args), x0=p0)

In [ ]:
res.x

In [ ]:
y = np.linspace(-10, 20, 256)

min_pars = mmix.unpack_pars(res.x)
ll = mmix.ln_norm_func(y, **min_pars)

fig,axes = plt.subplots(1, 2, figsize=(12,5))

axes[0].plot(y, np.exp(ll), marker='')
axes[0].set_xlim(-10, 20)
axes[0].set_xlabel(r'$y=\ln\left(\frac{s}{1\,{\rm m}\,{\rm s}^{-1}} \right)^2$')

# change of variables: y = ln(s^2), so p(s) = p(y) * |dy/ds| = p(y) * 2/s
s = np.sqrt(np.exp(y))
axes[1].plot(s, np.exp(ll) * 2/s, marker='')
axes[1].set_xlim(-0.1, 400)
axes[1].set_xlabel('jitter, $s$ [{0:latex_inline}]'.format(u.m/u.s))

# fig.savefig(...)  # TODO: savefig() requires an output filename


In [ ]:
import emcee

In [ ]:
mmix = Model(y_nk[slc], K_n[slc], ln_p0[slc], 
             n_components=2)

In [ ]:
ndim = 5
nwalkers = 8*ndim
p0 = np.random.normal([0.7, 7, 10, 2, 2], [1E-3]*ndim, size=(nwalkers, ndim))

for pp in p0:
    assert np.all(np.isfinite(mmix(pp)))

In [ ]:
mmix([0.8, 7, 10, 2, 2])

In [ ]:
# emcee v2 API; in emcee >= 3 this would be EnsembleSampler(nwalkers, ndim, mmix)
sampler = emcee.EnsembleSampler(nwalkers, dim=ndim, lnpostfn=mmix)
pos,*_ = sampler.run_mcmc(p0, 1024)
# sampler.reset()
# _ = sampler.run_mcmc(pos, 512)

In [ ]:
sampler.chain.shape

In [ ]:
for dim in range(sampler.dim):
    plt.figure()
    for walker in sampler.chain[...,dim]:
        plt.plot(walker, marker='', linestyle='-', color='k', 
                 alpha=0.2, drawstyle='steps-mid')

In [ ]:
# discard burn-in, thin, and flatten the walkers into an (nsamples, ndim) array
samples = sampler.chain[:, 500::8].reshape(-1, ndim)
med_pars = mmix.unpack_pars(np.median(samples, axis=0))
med_pars

In [ ]:
y = np.linspace(-10, 20, 256)

ll = mmix.ln_norm_func(y, **med_pars)

fig,axes = plt.subplots(1, 2, figsize=(12,5), sharey=True)

axes[0].plot(y, np.exp(ll), marker='')
axes[0].set_xlim(-10, 20)
axes[0].set_xlabel(r'$y=\ln\left(\frac{s}{1\,{\rm m}\,{\rm s}^{-1}} \right)^2$')

# same change of variables as above: p(s) = p(y) * 2/s
s = np.sqrt(np.exp(y))
axes[1].plot(s, np.exp(ll) * 2/s, marker='')
axes[1].set_xlim(-10, 500)
axes[1].set_xlabel('jitter, $s$ [{0:latex_inline}]'.format(u.m/u.s))
